{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h2q27gKz1H20"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "TUfAcER1oUS6"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gb7qyhNL1yWt"
      },
      "source": [
        "# BERT Question Answer with TensorFlow Lite Model Maker"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fw5Y7snSuG51"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/tutorials/model_maker_question_answer\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/tutorials/model_maker_question_answer.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sr3q-gvm3cI8"
      },
      "source": [
        "The [TensorFlow Lite Model Maker library](https://www.tensorflow.org/lite/guide/model_maker) simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.\n",
        "\n",
        "This notebook shows an end-to-end example that uses the Model Maker library to adapt and convert a commonly used question answer model for the question answer task."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UxEHFTk755qw"
      },
      "source": [
        "# Introduction to BERT Question Answer Task"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cFbKTCF25-SG"
      },
      "source": [
        "The task supported by this library is the extractive question answer task: given a passage and a question, the answer is a span in the passage. The image below shows an example.\n",
        "\n",
        "\n",
        "\u003cp align=\"center\"\u003e\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_squad_showcase.png\"  width=\"500\"\u003e\u003c/p\u003e\n",
        "\n",
        "\u003cp align=\"center\"\u003e\n",
        "    \u003cem\u003eAnswers are spans in the passage (image credit: \u003ca href=\"https://rajpurkar.github.io/mlx/qa-and-squad/\"\u003eSQuAD blog\u003c/a\u003e) \u003c/em\u003e\n",
        "\u003c/p\u003e\n",
        "\n",
        "For a question answer model, the inputs are the preprocessed passage and question pair, and the outputs are the start logits and end logits for each token in the passage.\n",
        "The input size can be set and adjusted according to the length of the passage and question."
      ]
    },
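    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the start and end logits more concrete, the sketch below picks the highest-scoring answer span from per-token logits. It is a simplified plain-Python illustration, not the Model Maker or Task Library implementation; the `best_span` helper, the `max_answer_len` limit, and the toy logits are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "def best_span(start_logits, end_logits, max_answer_len=30):\n",
        "    # Returns (start, end, score) for the highest-scoring valid span.\n",
        "    candidates = []\n",
        "    for start, start_score in enumerate(start_logits):\n",
        "        # An answer cannot end before it starts or exceed max_answer_len tokens.\n",
        "        last = min(len(end_logits), start + max_answer_len)\n",
        "        for end in range(start, last):\n",
        "            candidates.append((start, end, start_score + end_logits[end]))\n",
        "    return max(candidates, key=lambda span: span[2])\n",
        "\n",
        "# Toy passage tokens and logits (hypothetical values, one logit per token).\n",
        "tokens = ['the', 'capital', 'of', 'france', 'is', 'paris']\n",
        "start_logits = [0.1, 0.2, 0.0, 0.3, 0.1, 2.5]\n",
        "end_logits = [0.0, 0.1, 0.0, 0.4, 0.2, 3.0]\n",
        "\n",
        "start, end, _ = best_span(start_logits, end_logits)\n",
        "answer = ' '.join(tokens[start:end + 1])\n",
        "```\n",
        "\n",
        "With these toy values, `answer` is `paris`, because the last token has the largest start and end logits."
      ]
    },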
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gb7P4WQta8Ub"
      },
      "source": [
        "## End-to-End Overview\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w7cIHjIfbDlG"
      },
      "source": [
        "The following code snippet demonstrates how to get the model within a few lines of code. The overall process includes 5 steps: (1) choose a model, (2) load data, (3) retrain the model, (4) evaluate, and (5) export it to TensorFlow Lite format."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xQPdlxZBYuZG"
      },
      "source": [
        "```python\n",
        "# Chooses a model specification that represents the model.\n",
        "spec = model_spec.get('mobilebert_qa')\n",
        "\n",
        "# Gets the training data and validation data.\n",
        "train_data = DataLoader.from_squad(train_data_path, spec, is_training=True)\n",
        "validation_data = DataLoader.from_squad(validation_data_path, spec, is_training=False)\n",
        "\n",
        "# Fine-tunes the model.\n",
        "model = question_answer.create(train_data, model_spec=spec)\n",
        "\n",
        "# Gets the evaluation result.\n",
        "metric = model.evaluate(validation_data)\n",
        "\n",
        "# Exports the model to the TensorFlow Lite format with metadata in the export directory.\n",
        "model.export(export_dir)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "exScAdvBbNEi"
      },
      "source": [
        "The following sections explain the code in more detail."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bcLF2PKkSbV3"
      },
      "source": [
        "## Prerequisites\n",
        "\n",
        "To run this example, install the required packages, including the Model Maker package from the [GitHub repo](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qhl8lqVamEty"
      },
      "outputs": [],
      "source": [
        "!pip install -q tflite-model-maker"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l6lRhVK9Q_0U"
      },
      "source": [
        "Import the required packages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XtxiUeZEiXpt"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import os\n",
        "\n",
        "import tensorflow as tf\n",
        "assert tf.__version__.startswith('2')\n",
        "\n",
        "from tflite_model_maker import model_spec\n",
        "from tflite_model_maker import question_answer\n",
        "from tflite_model_maker.config import ExportFormat\n",
        "from tflite_model_maker.question_answer import DataLoader"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l65ctmtW7_FF"
      },
      "source": [
        "The \"End-to-End Overview\" demonstrates a simple end-to-end example. The following sections walk through the example step by step to show more detail."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kJ_B8fMDOhMR"
      },
      "source": [
        "## Choose a model_spec that represents a model for question answer\n",
        "\n",
        "Each `model_spec` object represents a specific model for question answer. The Model Maker currently supports MobileBERT and BERT-Base models.\n",
        "\n",
        "Supported Model | Name of model_spec | Model Description\n",
        "--- | --- | ---\n",
        "[MobileBERT](https://arxiv.org/pdf/2004.02984.pdf)  | 'mobilebert_qa' | 4.3x smaller and 5.5x faster than BERT-Base while achieving competitive results, suitable for on-device scenarios.\n",
        "[MobileBERT-SQuAD](https://arxiv.org/pdf/2004.02984.pdf)  | 'mobilebert_qa_squad' | Same model architecture as MobileBERT, with an initial model that is already retrained on [SQuAD1.1](https://rajpurkar.github.io/SQuAD-explorer/).\n",
        "[BERT-Base](https://arxiv.org/pdf/1810.04805.pdf) | 'bert_qa' | Standard BERT model that is widely used in NLP tasks.\n",
        "\n",
        "In this tutorial, [MobileBERT-SQuAD](https://arxiv.org/pdf/2004.02984.pdf) is used as an example. Since the model is already retrained on [SQuAD1.1](https://rajpurkar.github.io/SQuAD-explorer/), it could converge faster for the question answer task.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vEAWuZQ1PFiX"
      },
      "outputs": [],
      "source": [
        "spec = model_spec.get('mobilebert_qa_squad')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ygEncJxtl-nQ"
      },
      "source": [
        "## Load Input Data Specific to an On-device ML App and Preprocess the Data\n",
        "\n",
        "The [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) is a reading comprehension dataset containing over 650K question-answer-evidence triples. In this tutorial, you will use a subset of this dataset to learn how to use the Model Maker library.\n",
        "\n",
        "To load the data, convert the TriviaQA dataset to the [SQuAD1.1](https://rajpurkar.github.io/SQuAD-explorer/) format by running the [converter Python script](https://github.com/mandarjoshi90/triviaqa#miscellaneous) with `--sample_size=8000` and a set of `web` data. Modify the conversion code a little bit by:\n",
        "* Skipping the samples for which no answer could be found in the context document;\n",
        "* Taking the answer text as it appears in the context, preserving the original casing.\n",
        "\n",
        "Download the archived version of the already converted dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7tOfUr2KlgpU"
      },
      "outputs": [],
      "source": [
        "train_data_path = tf.keras.utils.get_file(\n",
        "    fname='triviaqa-web-train-8000.json',\n",
        "    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json')\n",
        "validation_data_path = tf.keras.utils.get_file(\n",
        "    fname='triviaqa-verified-web-dev.json',\n",
        "    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UfZk8GNr_1nc"
      },
      "source": [
        "You can also train the MobileBERT model with your own dataset. If you are running this notebook on Colab, upload your data by using the left sidebar.\n",
        "\n",
        "\u003cimg src=\"https://storage.googleapis.com/download.tensorflow.org/models/tflite/screenshots/model_maker_question_answer.png\" alt=\"Upload File\" width=\"800\" hspace=\"100\"\u003e\n",
        "\n",
        "If you prefer not to upload your data to the cloud, you can also run the library offline by following the [guide](https://github.com/tensorflow/examples/tree/master/tensorflow_examples/lite/model_maker)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E051HBUM5owi"
      },
      "source": [
        "Use the `DataLoader.from_squad` method to load and preprocess the [SQuAD format](https://rajpurkar.github.io/SQuAD-explorer/) data according to a specific `model_spec`. You can use either the SQuAD2.0 or the SQuAD1.1 format. Setting the `version_2_with_negative` parameter to `True` means the format is SQuAD2.0. Otherwise, the format is SQuAD1.1. By default, `version_2_with_negative` is `False`."
      ]
    },
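    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you are preparing your own dataset, it may help to see what a SQuAD1.1-style record looks like. The sketch below builds a single-question example with the standard library only; the passage, the question, the `example-0` id, and the variable names are made up for illustration.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "context = 'TensorFlow Lite Model Maker simplifies adapting a model for on-device ML.'\n",
        "answer_text = 'on-device ML'\n",
        "\n",
        "squad_sample = {\n",
        "    'version': '1.1',\n",
        "    'data': [{\n",
        "        'title': 'example',\n",
        "        'paragraphs': [{\n",
        "            'context': context,\n",
        "            'qas': [{\n",
        "                'id': 'example-0',\n",
        "                'question': 'What does Model Maker adapt a model for?',\n",
        "                'answers': [{\n",
        "                    'text': answer_text,\n",
        "                    # Character offset of the answer inside the context.\n",
        "                    'answer_start': context.index(answer_text),\n",
        "                }],\n",
        "            }],\n",
        "        }],\n",
        "    }],\n",
        "}\n",
        "\n",
        "squad_json = json.dumps(squad_sample, indent=2)\n",
        "```\n",
        "\n",
        "Saving `squad_json` to a file should give a small dataset in the layout that `DataLoader.from_squad` reads. In the SQuAD2.0 format, each question additionally carries an `is_impossible` flag that marks unanswerable questions."
      ]
    },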
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I_fOlZsklmlL"
      },
      "outputs": [],
      "source": [
        "train_data = DataLoader.from_squad(train_data_path, spec, is_training=True)\n",
        "validation_data = DataLoader.from_squad(validation_data_path, spec, is_training=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AWuoensX4vDA"
      },
      "source": [
        "## Customize the TensorFlow Model\n",
        "\n",
        "Create a custom question answer model based on the loaded data. The `create` function comprises the following steps:\n",
        "\n",
        "1. Creates the model for question answer according to `model_spec`.\n",
        "2. Trains the question answer model. The default number of epochs and the default batch size are set by the `default_training_epochs` and `default_batch_size` variables in the `model_spec` object."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TvYSUuJY3QxR"
      },
      "outputs": [],
      "source": [
        "model = question_answer.create(train_data, model_spec=spec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0JKI-pNc8idH"
      },
      "source": [
        "Have a look at the detailed model structure."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gd7Hs8TF8n3H"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LP5FPk_tOxoZ"
      },
      "source": [
        "## Evaluate the Customized Model\n",
        "\n",
        "Evaluate the model on the validation data and get a dict of metrics, including the `f1` score and `exact match`. Note that the metrics differ between SQuAD1.1 and SQuAD2.0."
      ]
    },
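    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For intuition about these two metrics, the sketch below computes an exact match and a token-level F1 score for one prediction, following the general SQuAD1.1 evaluation convention. It is a simplified illustration and not the code that `model.evaluate` runs; the helper names and the sample strings are assumptions.\n",
        "\n",
        "```python\n",
        "import collections\n",
        "import re\n",
        "import string\n",
        "\n",
        "def normalize(text):\n",
        "    # Lowercases, strips punctuation and English articles, collapses whitespace.\n",
        "    text = text.lower()\n",
        "    text = ''.join(ch for ch in text if ch not in string.punctuation)\n",
        "    text = re.sub(r'\\b(a|an|the)\\b', ' ', text)\n",
        "    return ' '.join(text.split())\n",
        "\n",
        "def exact_match(prediction, reference):\n",
        "    return float(normalize(prediction) == normalize(reference))\n",
        "\n",
        "def f1_score(prediction, reference):\n",
        "    pred_counts = collections.Counter(normalize(prediction).split())\n",
        "    ref_counts = collections.Counter(normalize(reference).split())\n",
        "    # Number of tokens shared by the prediction and the reference.\n",
        "    num_same = sum(min(pred_counts[t], ref_counts[t]) for t in pred_counts)\n",
        "    if num_same == 0:\n",
        "        return 0.0\n",
        "    precision = num_same / sum(pred_counts.values())\n",
        "    recall = num_same / sum(ref_counts.values())\n",
        "    return 2 * precision * recall / (precision + recall)\n",
        "\n",
        "em = exact_match('The Eiffel Tower', 'eiffel tower')\n",
        "f1 = f1_score('Eiffel Tower in Paris', 'the Eiffel Tower')\n",
        "```\n",
        "\n",
        "Here `em` is 1.0 because the two strings match after normalization, while `f1` is about 0.67 because only two of the four predicted tokens appear in the reference."
      ]
    },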
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A8c2ZQ0J3Riy"
      },
      "outputs": [],
      "source": [
        "model.evaluate(validation_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aeHoGAceO2xV"
      },
      "source": [
        "## Export to TensorFlow Lite Model\n",
        "\n",
        "Convert the trained model to TensorFlow Lite model format with [metadata](https://www.tensorflow.org/lite/convert/metadata) so that you can later use it in an on-device ML application. The vocab file is embedded in the metadata. The default TFLite filename is `model.tflite`.\n",
        "\n",
        "In many on-device ML applications, the model size is an important factor. Therefore, it is recommended that you quantize the model to make it smaller and potentially run faster.\n",
        "The default post-training quantization technique is dynamic range quantization for the BERT and MobileBERT models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Im6wA9lK3TQB"
      },
      "outputs": [],
      "source": [
        "model.export(export_dir='.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w12kvDdHJIGH"
      },
      "source": [
        "Download the TensorFlow Lite model file from the left sidebar on Colab. You can then use it in the [bert_qa](https://github.com/tensorflow/examples/tree/master/lite/examples/bert_qa/android) reference app through the [BertQuestionAnswerer API](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer) in the [TensorFlow Lite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VFnJPvq3VGh3"
      },
      "source": [
        "The allowed export formats can be one or a list of the following:\n",
        "\n",
        "*   `ExportFormat.TFLITE`\n",
        "*   `ExportFormat.VOCAB`\n",
        "*   `ExportFormat.SAVED_MODEL`\n",
        "\n",
        "By default, only the TensorFlow Lite model with metadata is exported. You can also selectively export different files. For instance, export only the vocab file as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ro2hz4kXVImY"
      },
      "outputs": [],
      "source": [
        "model.export(export_dir='.', export_format=ExportFormat.VOCAB)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HZKYthlVrTos"
      },
      "source": [
        "You can also evaluate the TensorFlow Lite model with the `evaluate_tflite` method. This step is expected to take a long time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ochbq95ZrVFX"
      },
      "outputs": [],
      "source": [
        "model.evaluate_tflite('model.tflite', validation_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EoWiA_zX8rxE"
      },
      "source": [
        "## Advanced Usage\n",
        "\n",
        "The `create` function is the critical part of this library, in which the `model_spec` parameter defines the model specification. The `BertQASpec` class is currently supported, with 2 models: the MobileBERT model and the BERT-Base model. The `create` function comprises the following steps:\n",
        "\n",
        "1. Creates the model for question answer according to `model_spec`.\n",
        "2. Trains the question answer model.\n",
        "\n",
        "This section describes several advanced topics, including adjusting the model and tuning the training hyperparameters."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mwtiksguDfhl"
      },
      "source": [
        "### Adjust the model\n",
        "\n",
        "You can adjust model settings such as the `seq_len` and `query_len` parameters in the `BertQASpec` class.\n",
        "\n",
        "Adjustable parameters for model:\n",
        "\n",
        "* `seq_len`: Length of the passage to feed into the model.\n",
        "* `query_len`: Length of the question to feed into the model.\n",
        "* `doc_stride`: The stride when doing a sliding window approach to take chunks of the documents.\n",
        "* `initializer_range`: The stdev of the truncated_normal_initializer for initializing all weight matrices.\n",
        "* `trainable`: Boolean, whether pre-trained layer is trainable.\n",
        "\n",
        "Adjustable parameters for training pipeline:\n",
        "\n",
        "* `model_dir`: The location of the model checkpoint files. If not set, a temporary directory will be used.\n",
        "* `dropout_rate`: The rate for dropout.\n",
        "* `learning_rate`: The initial learning rate for Adam.\n",
        "* `predict_batch_size`: Batch size for prediction.\n",
        "* `tpu`: TPU address to connect to. Only used when training on a TPU.\n"
      ]
    },
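    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `doc_stride` parameter can be pictured with the sketch below, which splits a long passage into overlapping windows in the style of the original BERT SQuAD preprocessing. Whether Model Maker follows exactly this procedure internally is an assumption; `split_into_windows` and `max_doc_len` are hypothetical names used only for this illustration.\n",
        "\n",
        "```python\n",
        "def split_into_windows(doc_tokens, max_doc_len, doc_stride):\n",
        "    # Returns (start, length) pairs for overlapping windows over doc_tokens.\n",
        "    # Both max_doc_len and doc_stride are assumed to be positive integers.\n",
        "    windows = []\n",
        "    start = 0\n",
        "    while True:\n",
        "        length = min(len(doc_tokens) - start, max_doc_len)\n",
        "        windows.append((start, length))\n",
        "        if start + length == len(doc_tokens):\n",
        "            break\n",
        "        # Advance by the stride so that consecutive windows overlap.\n",
        "        start += min(length, doc_stride)\n",
        "    return windows\n",
        "\n",
        "doc_tokens = ['tok%d' % i for i in range(10)]\n",
        "windows = split_into_windows(doc_tokens, max_doc_len=4, doc_stride=2)\n",
        "```\n",
        "\n",
        "For this 10-token passage, the windows start at tokens 0, 2, 4, and 6, each covering 4 tokens, so every token away from the passage edges appears in two windows."
      ]
    },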
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cAOd5_bzH9AQ"
      },
      "source": [
        "For example, you can train the model with a longer sequence length. If you change the model, you must first construct a new `model_spec`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e9WBN0UTQoMN"
      },
      "outputs": [],
      "source": [
        "new_spec = model_spec.get('mobilebert_qa')\n",
        "new_spec.seq_len = 512"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6LSTdghTP0Cv"
      },
      "source": [
        "The remaining steps are the same. Note that you must rerun both the `dataloader` and `create` parts as different model specs may have different preprocessing steps.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LvQuy7RSDir3"
      },
      "source": [
        "### Tune training hyperparameters\n",
        "You can also tune the training hyperparameters like `epochs` and `batch_size` to impact the model performance. For instance,\n",
        "\n",
        "*   `epochs`: more epochs could achieve better performance, but may lead to overfitting.\n",
        "*   `batch_size`: number of samples to use in one training step.\n",
        "\n",
        "For example, you can train with more epochs and with a bigger batch size like:\n",
        "\n",
        "```python\n",
        "model = question_answer.create(train_data, model_spec=spec, epochs=5, batch_size=64)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Eq6B9lKMfhS6"
      },
      "source": [
        "### Change the Model Architecture\n",
        "\n",
        "You can change the base model your data trains on by changing the `model_spec`. For example, to change to the BERT-Base model, run:\n",
        "\n",
        "```python\n",
        "spec = model_spec.get('bert_qa')\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L2d7yycrgu6L"
      },
      "source": [
        "The remaining steps are the same."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wFQrDMXzOVoB"
      },
      "source": [
        "### Customize Post-training quantization on the TensorFlow Lite model\n",
        "\n",
        "[Post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) is a conversion technique that can reduce model size and inference latency, while also improving CPU and hardware accelerator inference speed, with a little degradation in model accuracy. Thus, it's widely used to optimize the model.\n",
        "\n",
        "The Model Maker library applies a default post-training quantization technique when exporting the model. If you want to customize post-training quantization, Model Maker also supports multiple post-training quantization options using [QuantizationConfig](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/config/QuantizationConfig). Take float16 quantization as an example. First, define the quantization config.\n",
        "\n",
        "```python\n",
        "from tflite_model_maker.config import QuantizationConfig\n",
        "\n",
        "config = QuantizationConfig.for_float16()\n",
        "```\n",
        "\n",
        "\n",
        "Then export the TensorFlow Lite model with this configuration.\n",
        "\n",
        "```python\n",
        "model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wPVopCeB6LV6"
      },
      "source": [
        "# Read more\n",
        "\n",
        "You can read our [BERT Question and Answer](https://www.tensorflow.org/lite/examples/bert_qa/overview) example to learn technical details. For more information, please refer to:\n",
        "\n",
        "*   TensorFlow Lite Model Maker [guide](https://www.tensorflow.org/lite/guide/model_maker) and [API reference](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker).\n",
        "*   Task Library: [BertQuestionAnswerer](https://www.tensorflow.org/lite/inference_with_metadata/task_library/bert_question_answerer) for deployment.\n",
        "*   The end-to-end reference apps: [Android](https://github.com/tensorflow/examples/tree/master/lite/examples/bert_qa/android) and [iOS](https://github.com/tensorflow/examples/tree/master/lite/examples/bert_qa/ios)."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Model Maker Question Answer Tutorial",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
